{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from numpy import ndarray"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def assert_same_shape(output: ndarray,\n",
    "                      output_grad: ndarray) -> None:\n",
    "    '''\n",
    "    Assert that two arrays have identical shapes.\n",
    "\n",
    "    output: first ndarray.\n",
    "    output_grad: second ndarray, expected to have the same shape as `output`.\n",
    "\n",
    "    Raises an AssertionError reporting both shapes when they differ;\n",
    "    returns None otherwise.\n",
    "    '''\n",
    "    # {0} must receive the first argument's shape and {1} the second's,\n",
    "    # so that the message attributes each shape to the right array.\n",
    "    assert output.shape == output_grad.shape, \\\n",
    "    '''\n",
    "    Two ndarray should have the same shape; instead, first ndarray's shape is {0}\n",
    "    and second ndarray's shape is {1}.\n",
    "    '''.format(tuple(output.shape), tuple(output_grad.shape))\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def assert_dim(t: ndarray,\n",
    "               dim: int) -> None:\n",
    "    '''\n",
    "    Assert that `t` has exactly `dim` dimensions.\n",
    "\n",
    "    t: ndarray to check.\n",
    "    dim: expected number of dimensions (an integer, e.g. 1 for a vector).\n",
    "\n",
    "    Raises an AssertionError reporting the expected and actual number of\n",
    "    dimensions when they differ; returns None otherwise.\n",
    "    '''\n",
    "    assert t.ndim == dim, \\\n",
    "    '''\n",
    "    Tensor expected to have dimension {0}, instead has dimension {1}\n",
    "    '''.format(dim, t.ndim)\n",
    "    return None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1D Convolution"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Simplest case: one 1D input and one 1D filter produce one 1D output (no batch or channel dimension)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Padding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_1d = np.array([1,2,3,4,5])\n",
    "param_1d = np.array([1,1,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _pad_1d(inp: ndarray,\n",
    "            num: int) -> ndarray:\n",
    "    '''\n",
    "    Zero-pad a 1D array with `num` zeros on each end.\n",
    "\n",
    "    inp: 1D ndarray to pad.\n",
    "    num: number of zeros to add on each side.\n",
    "\n",
    "    Returns a new ndarray of length len(inp) + 2 * num.\n",
    "    '''\n",
    "    zeros = np.zeros(num, dtype=int)\n",
    "    return np.concatenate((zeros, inp, zeros))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 1, 2, 3, 4, 5, 0])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_pad_1d(input_1d, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Forward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv_1d(inp: ndarray,\n",
    "            param: ndarray) -> ndarray:\n",
    "    '''\n",
    "    Slide the 1D filter `param` across the 1D array `inp`.\n",
    "\n",
    "    The input is zero-padded with len(param) // 2 zeros on each side, and\n",
    "    the returned float array has the same shape as `inp`.\n",
    "    '''\n",
    "    # both arguments must be vectors\n",
    "    assert_dim(inp, 1)\n",
    "    assert_dim(param, 1)\n",
    "\n",
    "    # zero-pad by half the filter length on each side\n",
    "    filter_len = param.shape[0]\n",
    "    padded = _pad_1d(inp, filter_len // 2)\n",
    "\n",
    "    # result[start] accumulates the filter applied to the window of the\n",
    "    # padded input that begins at position `start`\n",
    "    result = np.zeros(inp.shape)\n",
    "    for start in range(result.shape[0]):\n",
    "        for offset, weight in enumerate(param):\n",
    "            result[start] += weight * padded[start + offset]\n",
    "\n",
    "    # the output must line up element-for-element with the input\n",
    "    assert_same_shape(inp, result)\n",
    "\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv_1d_sum(inp: ndarray,\n",
    "                param: ndarray) -> float:\n",
    "    '''\n",
    "    Sum of every element of conv_1d(inp, param).\n",
    "\n",
    "    Collapses the convolution output to a single scalar (not an ndarray),\n",
    "    which the notebook uses as the quantity to differentiate when checking\n",
    "    gradients by finite differences.\n",
    "    '''\n",
    "    out = conv_1d(inp, param)\n",
    "    return np.sum(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "39.0"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "conv_1d_sum(input_1d, param_1d)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Testing gradients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4\n",
      "0\n"
     ]
    }
   ],
   "source": [
    "np.random.seed(190220)\n",
    "print(np.random.randint(0, input_1d.shape[0]))\n",
    "print(np.random.randint(0, param_1d.shape[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_1d_2 = np.array([1,2,3,4,6])\n",
    "param_1d = np.array([1,1,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0\n"
     ]
    }
   ],
   "source": [
    "# finite-difference check for the input: input_1d_2 differs from input_1d\n",
    "# only at index 4, so the difference must be taken against input_1d\n",
    "print(conv_1d_sum(input_1d_2, param_1d) - conv_1d_sum(input_1d, param_1d))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10.0\n"
     ]
    }
   ],
   "source": [
    "input_1d = np.array([1,2,3,4,5])\n",
    "param_1d_2 = np.array([2,1,1])\n",
    "\n",
    "print(conv_1d_sum(input_1d, param_1d_2) - conv_1d_sum(input_1d, param_1d))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Gradients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _param_grad_1d(inp: ndarray,\n",
    "                   param: ndarray,\n",
    "                   output_grad: ndarray = None) -> ndarray:\n",
    "    '''\n",
    "    Gradient of the conv_1d output with respect to the filter `param`.\n",
    "\n",
    "    inp: 1D input that was convolved.\n",
    "    param: 1D filter used in the forward pass.\n",
    "    output_grad: gradient flowing back from the output, same shape as `inp`;\n",
    "        defaults to all ones (the gradient of the summed output).\n",
    "\n",
    "    Returns an ndarray with the same shape as `param`.\n",
    "    '''\n",
    "    # pad exactly as the forward pass does so that the indices line up\n",
    "    input_pad = _pad_1d(inp, param.shape[0] // 2)\n",
    "\n",
    "    if output_grad is None:\n",
    "        output_grad = np.ones_like(inp)\n",
    "    else:\n",
    "        assert_same_shape(inp, output_grad)\n",
    "\n",
    "    # Accumulate in the promoted dtype of all operands: an accumulator built\n",
    "    # with np.zeros_like(param) would silently truncate float contributions\n",
    "    # whenever `param` is an integer array.\n",
    "    param_grad = np.zeros(param.shape,\n",
    "                          dtype=np.result_type(inp, param, output_grad))\n",
    "\n",
    "    # out[o] = sum_p param[p] * input_pad[o+p], hence\n",
    "    # d(sum)/d(param[p]) = sum_o input_pad[o+p] * output_grad[o]\n",
    "    for o in range(inp.shape[0]):\n",
    "        for p in range(param.shape[0]):\n",
    "            param_grad[p] += input_pad[o+p] * output_grad[o]\n",
    "\n",
    "    assert_same_shape(param_grad, param)\n",
    "\n",
    "    return param_grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _input_grad_1d(inp: ndarray,\n",
    "                   param: ndarray,\n",
    "                   output_grad: ndarray = None) -> ndarray:\n",
    "    '''\n",
    "    Gradient of the conv_1d output with respect to the input `inp`.\n",
    "\n",
    "    inp: 1D input that was convolved.\n",
    "    param: 1D filter used in the forward pass.\n",
    "    output_grad: gradient flowing back from the output, same shape as `inp`;\n",
    "        defaults to all ones (the gradient of the summed output).\n",
    "\n",
    "    Returns an ndarray with the same shape as `inp`.\n",
    "    '''\n",
    "    param_len = param.shape[0]\n",
    "    param_mid = param_len // 2\n",
    "\n",
    "    if output_grad is None:\n",
    "        output_grad = np.ones_like(inp)\n",
    "    else:\n",
    "        assert_same_shape(inp, output_grad)\n",
    "\n",
    "    # pad the output gradient so that positions falling outside the output\n",
    "    # contribute zero\n",
    "    output_pad = _pad_1d(output_grad, param_mid)\n",
    "\n",
    "    # Accumulate in the promoted dtype of all operands so float filters or\n",
    "    # float output gradients are not truncated by an integer `inp`.\n",
    "    input_grad = np.zeros(inp.shape,\n",
    "                          dtype=np.result_type(inp, param, output_grad))\n",
    "\n",
    "    # inp[o] is multiplied by param[f] in out[o + param_mid - f]; in padded\n",
    "    # coordinates that is index o + 2 * param_mid - f. Using 2 * param_mid\n",
    "    # (rather than param_len - 1) keeps the offset right for even-length\n",
    "    # filters as well; for odd lengths the two are equal.\n",
    "    for o in range(inp.shape[0]):\n",
    "        for f in range(param_len):\n",
    "            input_grad[o] += output_pad[o + 2 * param_mid - f] * param[f]\n",
    "\n",
    "    assert_same_shape(input_grad, inp)\n",
    "\n",
    "    return input_grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([2, 3, 3, 3, 2])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_input_grad_1d(input_1d, param_1d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([10, 15, 14])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_param_grad_1d(input_1d, param_1d)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Works! The parameter gradient at index 0 (10) equals the finite-difference change of 10.0 measured above."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Batch size of 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_1d_batch = np.array([[0,1,2,3,4,5,6], \n",
    "                           [1,2,3,4,5,6,7]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _pad_1d(inp: ndarray,\n",
    "            num: int) -> ndarray:\n",
    "    '''\n",
    "    Zero-pad a 1D array with `num` zeros on each side.\n",
    "\n",
    "    NOTE: this repeats the `_pad_1d` definition from the earlier 'Padding'\n",
    "    section; running this cell simply rebinds the same function.\n",
    "    '''\n",
    "    border = np.repeat(np.array([0]), num)\n",
    "    padded = np.concatenate([border, inp, border])\n",
    "    return padded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _pad_1d_batch(inp: ndarray,\n",
    "                  num: int) -> ndarray:\n",
    "    '''\n",
    "    Zero-pad every row of a batch with `num` zeros on each side.\n",
    "\n",
    "    Each item obtained by iterating over `inp` is padded with _pad_1d and\n",
    "    the padded rows are stacked back into a single ndarray.\n",
    "    '''\n",
    "    padded_rows = []\n",
    "    for row in inp:\n",
    "        padded_rows.append(_pad_1d(row, num))\n",
    "    return np.stack(padded_rows)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0, 0, 1, 2, 3, 4, 5, 6, 0],\n",
       "       [0, 1, 2, 3, 4, 5, 6, 7, 0]])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_pad_1d_batch(input_1d_batch, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Forward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv_1d_batch(inp: ndarray,\n",
    "                  param: ndarray) -> ndarray:\n",
    "    '''\n",
    "    Apply conv_1d to each observation (row) of `inp` with the shared filter\n",
    "    `param`, and stack the per-observation outputs.\n",
    "    '''\n",
    "    per_obs = []\n",
    "    for obs in inp:\n",
    "        per_obs.append(conv_1d(obs, param))\n",
    "    return np.stack(per_obs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 1.,  3.,  6.,  9., 12., 15., 11.],\n",
       "       [ 3.,  6.,  9., 12., 15., 18., 13.]])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "conv_1d_batch(input_1d_batch, param_1d)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gradient"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "def input_grad_1d_batch(inp: ndarray,\n",
    "                        param: ndarray) -> ndarray:\n",
    "    '''\n",
    "    Gradient of the summed conv_1d_batch output with respect to `inp`.\n",
    "\n",
    "    inp: [batch_size, length] batch of 1D observations.\n",
    "    param: 1D filter shared by all observations.\n",
    "\n",
    "    Returns an ndarray with the same shape as `inp`.\n",
    "    '''\n",
    "    assert_dim(inp, 2)\n",
    "    assert_dim(param, 1)\n",
    "\n",
    "    # The quantity being differentiated is the plain sum of the outputs, so\n",
    "    # the gradient arriving at every output element is 1. The output has the\n",
    "    # same shape as the input, so no forward pass is needed to build it.\n",
    "    out_grad = np.ones_like(inp)\n",
    "\n",
    "    grads = [_input_grad_1d(obs, param, obs_grad)\n",
    "             for obs, obs_grad in zip(inp, out_grad)]\n",
    "\n",
    "    return np.stack(grads)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "def param_grad_1d_batch(inp: ndarray,\n",
    "                        param: ndarray) -> ndarray:\n",
    "    '''\n",
    "    Gradient of the summed conv_1d_batch output with respect to `param`.\n",
    "\n",
    "    inp: [batch_size, length] batch of 1D observations.\n",
    "    param: 1D filter shared by all observations.\n",
    "\n",
    "    Returns an ndarray with the same shape as `param`, accumulated over\n",
    "    every observation in the batch.\n",
    "    '''\n",
    "    # gradient of a plain sum with respect to each output element is 1\n",
    "    output_grad = np.ones_like(inp)\n",
    "\n",
    "    # Pad by half the filter length, exactly as the forward pass does; a\n",
    "    # fixed padding of 1 is only valid for filters of length 3.\n",
    "    inp_pad = _pad_1d_batch(inp, param.shape[0] // 2)\n",
    "\n",
    "    # promoted dtype avoids truncating float inputs when `param` is integer\n",
    "    param_grad = np.zeros(param.shape, dtype=np.result_type(inp, param))\n",
    "\n",
    "    for i in range(inp.shape[0]):\n",
    "        for o in range(inp.shape[1]):\n",
    "            for p in range(param.shape[0]):\n",
    "                param_grad[p] += inp_pad[i][o+p] * output_grad[i][o]\n",
    "\n",
    "    return param_grad"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Checking gradients for `conv_1d_batch`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv_1d_batch_sum(inp: ndarray,\n",
    "                      fil: ndarray) -> float:\n",
    "    '''\n",
    "    Sum of every element of conv_1d_batch(inp, fil).\n",
    "\n",
    "    Returns a single scalar (not an ndarray); used below as the quantity to\n",
    "    differentiate when checking the batch gradients by finite differences.\n",
    "    '''\n",
    "    out = conv_1d_batch(inp, fil)\n",
    "    return np.sum(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "133.0"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "conv_1d_batch_sum(input_1d_batch, param_1d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "2\n"
     ]
    }
   ],
   "source": [
    "print(np.random.randint(0, input_1d_batch.shape[0]))\n",
    "print(np.random.randint(0, input_1d_batch.shape[1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3.0"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_1d_batch_2 = input_1d_batch.copy()\n",
    "input_1d_batch_2[0][2] += 1\n",
    "conv_1d_batch_sum(input_1d_batch_2, param_1d) - conv_1d_batch_sum(input_1d_batch, param_1d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[2, 3, 3, 3, 3, 3, 2],\n",
       "       [2, 3, 3, 3, 3, 3, 2]])"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_grad_1d_batch(input_1d_batch, param_1d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n"
     ]
    }
   ],
   "source": [
    "print(np.random.randint(0, param_1d.shape[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "48.0"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "param_1d_2 = param_1d.copy()\n",
    "param_1d_2[2] += 1\n",
    "conv_1d_batch_sum(input_1d_batch, param_1d_2) - conv_1d_batch_sum(input_1d_batch, param_1d) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([36, 49, 48])"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "param_grad_1d_batch(input_1d_batch, param_1d)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2D Convolutions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "imgs_2d_batch = np.random.randn(3, 28, 28)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "param_2d = np.random.randn(3, 3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Padding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _pad_2d(inp: ndarray,\n",
    "            num: int):\n",
    "    '''\n",
    "    Zero-pad every observation of a 3 dimensional tensor whose first\n",
    "    dimension is the batch size; each observation is padded by _pad_2d_obs.\n",
    "    '''\n",
    "    padded = []\n",
    "    for obs in inp:\n",
    "        padded.append(_pad_2d_obs(obs, num))\n",
    "    return np.stack(padded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _pad_2d_obs(inp: ndarray,\n",
    "                num: int) -> ndarray:\n",
    "    '''\n",
    "    Zero-pad a single 2D array with `num` rows and columns of zeros on\n",
    "    every side.\n",
    "\n",
    "    inp: 2D ndarray of shape [rows, cols]; it does not need to be square.\n",
    "    num: number of zeros to add on each side of both dimensions.\n",
    "\n",
    "    Returns an ndarray of shape [rows + 2 * num, cols + 2 * num].\n",
    "    '''\n",
    "    # pad every row on the left and right\n",
    "    inp_pad = _pad_1d_batch(inp, num)\n",
    "\n",
    "    # Rows of zeros for the top and bottom. Their width must equal the\n",
    "    # padded row length, i.e. the number of columns (shape[1]) plus the\n",
    "    # padding, not the number of rows.\n",
    "    other = np.zeros((num, inp.shape[1] + num * 2))\n",
    "\n",
    "    return np.concatenate([other, inp_pad, other])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3, 30, 30)"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_pad_2d(imgs_2d_batch, 1).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compute output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _compute_output_obs_2d(obs: ndarray,\n",
    "                           param: ndarray):\n",
    "    '''\n",
    "    Zero-padded 2D convolution of one observation with one filter.\n",
    "\n",
    "    obs: 2D ndarray holding a single observation.\n",
    "    param: 2D filter; the padding is derived from param.shape[0], so the\n",
    "        filter is expected to be square.\n",
    "\n",
    "    Returns an ndarray with the same shape as `obs`.\n",
    "    '''\n",
    "    param_mid = param.shape[0] // 2\n",
    "\n",
    "    obs_pad = _pad_2d_obs(obs, param_mid)\n",
    "\n",
    "    # Allocate the output in the promoted dtype of both operands: an output\n",
    "    # built with np.zeros_like(obs) would silently truncate the results\n",
    "    # whenever `obs` is an integer array and `param` is not.\n",
    "    out = np.zeros(obs.shape, dtype=np.result_type(obs, param))\n",
    "\n",
    "    # out[o_w][o_h] is the filter applied to the window of the padded\n",
    "    # observation whose top-left corner is (o_w, o_h)\n",
    "    for o_w in range(out.shape[0]):\n",
    "        for o_h in range(out.shape[1]):\n",
    "            for p_w in range(param.shape[0]):\n",
    "                for p_h in range(param.shape[1]):\n",
    "                    out[o_w][o_h] += param[p_w][p_h] * obs_pad[o_w+p_w][o_h+p_h]\n",
    "    return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _compute_output_2d(img_batch: ndarray,\n",
    "                       param: ndarray):\n",
    "    '''\n",
    "    Convolve each observation of a 3 dimensional batch with the 2D filter\n",
    "    `param` and stack the per-observation outputs.\n",
    "    '''\n",
    "    # the batch must be [batch_size, width, height]\n",
    "    assert_dim(img_batch, 3)\n",
    "\n",
    "    per_obs = []\n",
    "    for obs in img_batch:\n",
    "        per_obs.append(_compute_output_obs_2d(obs, param))\n",
    "\n",
    "    return np.stack(per_obs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3, 28, 28)"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_compute_output_2d(imgs_2d_batch, param_2d).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Input and param grads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _compute_grads_obs_2d(input_obs: ndarray,\n",
    "                          output_grad_obs: ndarray,\n",
    "                          param: ndarray) -> ndarray:\n",
    "    '''\n",
    "    Input gradient for a single observation of the 2D convolution.\n",
    "\n",
    "    input_obs: 2D Tensor representing the input observation\n",
    "    output_grad_obs: 2D Tensor representing the output gradient\n",
    "    param: 2D filter\n",
    "\n",
    "    Returns a Tensor shaped like input_obs.\n",
    "    '''\n",
    "    size = param.shape[0]\n",
    "    grad_pad = _pad_2d_obs(output_grad_obs, size // 2)\n",
    "    result = np.zeros_like(input_obs)\n",
    "\n",
    "    rows, cols = input_obs.shape[0], input_obs.shape[1]\n",
    "    for row in range(rows):\n",
    "        for col in range(cols):\n",
    "            # walk the filter; the matching output gradient sits at the\n",
    "            # mirrored offset inside the padded gradient\n",
    "            for k_row, k_col in np.ndindex(size, size):\n",
    "                mirrored = grad_pad[row + size - k_row - 1][col + size - k_col - 1]\n",
    "                result[row][col] += mirrored * param[k_row][k_col]\n",
    "\n",
    "    return result\n",
    "\n",
    "def _compute_grads_2d(inp: ndarray,\n",
    "                      output_grad: ndarray,\n",
    "                      param: ndarray) -> ndarray:\n",
    "    '''\n",
    "    Input gradient for a batch: applies _compute_grads_obs_2d to each\n",
    "    observation and its output gradient, then stacks the results.\n",
    "    '''\n",
    "    per_obs = []\n",
    "    for i in range(output_grad.shape[0]):\n",
    "        per_obs.append(_compute_grads_obs_2d(inp[i], output_grad[i], param))\n",
    "\n",
    "    return np.stack(per_obs)\n",
    "\n",
    "\n",
    "def _param_grad_2d(inp: ndarray,\n",
    "                   output_grad: ndarray,\n",
    "                   param: ndarray) -> ndarray:\n",
    "    '''\n",
    "    Filter gradient of the batched 2D convolution.\n",
    "\n",
    "    inp: 3 dimensional batch, first dimension batch size\n",
    "    output_grad: gradient of the output, indexed like inp\n",
    "    param: 2D filter\n",
    "\n",
    "    Returns a Tensor shaped like param, accumulated over the whole batch.\n",
    "    '''\n",
    "    size = param.shape[0]\n",
    "    padded = _pad_2d(inp, size // 2)\n",
    "\n",
    "    grad = np.zeros_like(param)\n",
    "    out_rows, out_cols = output_grad.shape[1], output_grad.shape[2]\n",
    "\n",
    "    for i in range(inp.shape[0]):\n",
    "        for o_w, o_h in np.ndindex(out_rows, out_cols):\n",
    "            # every filter weight touched the padded input at (o_w+p_w, o_h+p_h)\n",
    "            for p_w, p_h in np.ndindex(size, size):\n",
    "                grad[p_w][p_h] += padded[i][o_w+p_w][o_h+p_h] * output_grad[i][o_w][o_h]\n",
    "    return grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "img_grads = _compute_grads_2d(imgs_2d_batch, \n",
    "                              np.ones_like(imgs_2d_batch),\n",
    "                              param_2d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3, 28, 28)"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "img_grads.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3, 3)"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "param_grad = _param_grad_2d(imgs_2d_batch, \n",
    "                              np.ones_like(imgs_2d_batch),\n",
    "                              param_2d)\n",
    "param_grad.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Testing gradients"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "6\n",
      "18\n"
     ]
    }
   ],
   "source": [
    "print(np.random.randint(0, imgs_2d_batch.shape[0]))\n",
    "print(np.random.randint(0, imgs_2d_batch.shape[1]))\n",
    "print(np.random.randint(0, imgs_2d_batch.shape[2]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "imgs_2d_batch_2 = imgs_2d_batch.copy()\n",
    "imgs_2d_batch_2[0][6][18] += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _compute_output_2d_sum(img_batch: ndarray,\n",
    "                           param: ndarray):\n",
    "    '''\n",
    "    Sum of every element of _compute_output_2d(img_batch, param); the scalar\n",
    "    used for the finite-difference gradient checks.\n",
    "    '''\n",
    "    return _compute_output_2d(img_batch, param).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-3.1843477398599163"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_compute_output_2d_sum(imgs_2d_batch_2, param_2d) - \\\n",
    "_compute_output_2d_sum(imgs_2d_batch, param_2d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-3.184347739859924"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "img_grads[0][6][18]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Param"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "2\n"
     ]
    }
   ],
   "source": [
    "print(np.random.randint(0, param_2d.shape[0]))\n",
    "print(np.random.randint(0, param_2d.shape[1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "param_2d_2 = param_2d.copy()\n",
    "param_2d_2[0][2] += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5.53349015923007"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_compute_output_2d_sum(imgs_2d_batch, param_2d_2) - _compute_output_2d_sum(imgs_2d_batch, param_2d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5.533490159230001"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "param_grad[0][2]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## With channels + batch size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Helper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _pad_2d_channel(inp: ndarray,\n",
    "                    num: int):\n",
    "    '''\n",
    "    Zero-pad each channel of one observation with _pad_2d_obs.\n",
    "\n",
    "    inp has dimension [num_channels, image_width, image_height]\n",
    "    '''\n",
    "    padded_channels = []\n",
    "    for channel in inp:\n",
    "        padded_channels.append(_pad_2d_obs(channel, num))\n",
    "    return np.stack(padded_channels)\n",
    "\n",
    "def _pad_conv_input(inp: ndarray,\n",
    "                    num: int):\n",
    "    '''\n",
    "    Zero-pad every observation of a batch with _pad_2d_channel.\n",
    "\n",
    "    inp has dimension [batch_size, num_channels, image_width, image_height]\n",
    "    '''\n",
    "    padded_obs = []\n",
    "    for obs in inp:\n",
    "        padded_obs.append(_pad_2d_channel(obs, num))\n",
    "    return np.stack(padded_obs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Forward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _compute_output_obs(obs: ndarray,\n",
    "                        param: ndarray):\n",
    "    '''\n",
    "    Multi-channel, zero-padded 2D convolution of a single observation.\n",
    "\n",
    "    obs: [channels, img_width, img_height]\n",
    "    param: [in_channels, out_channels, fil_width, fil_height]\n",
    "\n",
    "    Returns a float ndarray of shape [out_channels, img_width, img_height].\n",
    "    '''\n",
    "    assert_dim(obs, 3)\n",
    "    assert_dim(param, 4)\n",
    "\n",
    "    # the padding and both filter loops use param.shape[2], so the filter\n",
    "    # is expected to be square\n",
    "    param_size = param.shape[2]\n",
    "    param_mid = param_size // 2\n",
    "    obs_pad = _pad_2d_channel(obs, param_mid)\n",
    "\n",
    "    in_channels = param.shape[0]\n",
    "    out_channels = param.shape[1]\n",
    "\n",
    "    out = np.zeros((out_channels,) + obs.shape[1:])\n",
    "\n",
    "    # Iterate over each spatial dimension of the output separately, so the\n",
    "    # loops always cover exactly the array allocated above rather than\n",
    "    # reusing the first spatial size for both.\n",
    "    img_width, img_height = out.shape[1], out.shape[2]\n",
    "\n",
    "    for c_in in range(in_channels):\n",
    "        for c_out in range(out_channels):\n",
    "            for o_w in range(img_width):\n",
    "                for o_h in range(img_height):\n",
    "                    for p_w in range(param_size):\n",
    "                        for p_h in range(param_size):\n",
    "                            out[c_out][o_w][o_h] += \\\n",
    "                            param[c_in][c_out][p_w][p_h] * obs_pad[c_in][o_w+p_w][o_h+p_h]\n",
    "    return out\n",
    "\n",
    "def _output(inp: ndarray,\n",
    "                    param: ndarray) -> ndarray:\n",
    "    '''\n",
    "    Forward pass for a whole batch: runs _compute_output_obs on every\n",
    "    observation and stacks the results.\n",
    "\n",
    "    inp: [batch_size, channels, img_width, img_height]\n",
    "    param: [in_channels, out_channels, fil_width, fil_height]\n",
    "    '''\n",
    "    per_obs = []\n",
    "    for obs in inp:\n",
    "        per_obs.append(_compute_output_obs(obs, param))\n",
    "\n",
    "    return np.stack(per_obs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Backward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _compute_grads_obs(input_obs: ndarray,\n",
    "                       output_grad_obs: ndarray,\n",
    "                       param: ndarray) -> ndarray:\n",
    "    '''\n",
    "    Input gradient of the multi-channel convolution for one observation.\n",
    "\n",
    "    input_obs: [in_channels, img_width, img_height]\n",
    "    output_grad_obs: [out_channels, img_width, img_height]\n",
    "    param: [in_channels, out_channels, fil_width, fil_height]\n",
    "\n",
    "    Returns an ndarray shaped like input_obs.\n",
    "    '''\n",
    "    input_grad = np.zeros_like(input_obs)\n",
    "    param_size = param.shape[2]\n",
    "    param_mid = param_size // 2\n",
    "    in_channels = input_obs.shape[0]\n",
    "    out_channels = param.shape[1]\n",
    "\n",
    "    # pad the output gradient so positions outside the output contribute zero\n",
    "    output_obs_pad = _pad_2d_channel(output_grad_obs, param_mid)\n",
    "\n",
    "    for c_in in range(in_channels):\n",
    "        for c_out in range(out_channels):\n",
    "            for i_w in range(input_obs.shape[1]):\n",
    "                for i_h in range(input_obs.shape[2]):\n",
    "                    for p_w in range(param_size):\n",
    "                        for p_h in range(param_size):\n",
    "                            # the filter is traversed mirrored relative to\n",
    "                            # the forward pass\n",
    "                            input_grad[c_in][i_w][i_h] += \\\n",
    "                            output_obs_pad[c_out][i_w+param_size-p_w-1][i_h+param_size-p_h-1] \\\n",
    "                            * param[c_in][c_out][p_w][p_h]\n",
    "    return input_grad\n",
    "\n",
    "def _input_grad(inp: ndarray,\n",
    "                output_grad: ndarray,\n",
    "                param: ndarray) -> ndarray:\n",
    "    '''\n",
    "    Input gradient for a whole batch: applies _compute_grads_obs to each\n",
    "    observation and its output gradient, then stacks the results.\n",
    "    '''\n",
    "    per_obs = []\n",
    "    for i in range(output_grad.shape[0]):\n",
    "        per_obs.append(_compute_grads_obs(inp[i], output_grad[i], param))\n",
    "\n",
    "    return np.stack(per_obs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _param_grad(inp: ndarray,\n",
    "                output_grad: ndarray,\n",
    "                param: ndarray) -> ndarray:\n",
    "    '''\n",
    "    Filter gradient of the multi-channel convolution, accumulated over the\n",
    "    whole batch.\n",
    "\n",
    "    inp: [batch_size, in_channels, img_width, img_height]\n",
    "    output_grad: [batch_size, out_channels, img_width, img_height]\n",
    "    param: [in_channels, out_channels, fil_width, fil_height]\n",
    "\n",
    "    Returns an ndarray shaped like param.\n",
    "    '''\n",
    "    param_grad = np.zeros_like(param)\n",
    "    param_size = param.shape[2]\n",
    "    param_mid = param_size // 2\n",
    "    in_channels = inp.shape[1]\n",
    "    out_channels = output_grad.shape[1]\n",
    "\n",
    "    # pad exactly as the forward pass does so that the indices line up\n",
    "    inp_pad = _pad_conv_input(inp, param_mid)\n",
    "    img_shape = output_grad.shape[2:]\n",
    "\n",
    "    for i in range(inp.shape[0]):\n",
    "        for c_in in range(in_channels):\n",
    "            for c_out in range(out_channels):\n",
    "                for o_w in range(img_shape[0]):\n",
    "                    for o_h in range(img_shape[1]):\n",
    "                        for p_w in range(param_size):\n",
    "                            for p_h in range(param_size):\n",
    "                                param_grad[c_in][c_out][p_w][p_h] += \\\n",
    "                                inp_pad[i][c_in][o_w+p_w][o_h+p_h] \\\n",
    "                                * output_grad[i][c_out][o_w][o_h]\n",
    "    return param_grad"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Testing gradients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "cifar_imgs = np.random.randn(10, 3, 32, 32)\n",
    "cifar_param = np.random.randn(3, 16, 5, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3\n",
      "1\n",
      "2\n",
      "19\n",
      "\n",
      "0\n",
      "8\n",
      "0\n",
      "2\n"
     ]
    }
   ],
   "source": [
    "print(np.random.randint(0, cifar_imgs.shape[0]))\n",
    "print(np.random.randint(0, cifar_imgs.shape[1]))\n",
    "print(np.random.randint(0, cifar_imgs.shape[2]))\n",
    "print(np.random.randint(0, cifar_imgs.shape[3]))\n",
    "print()\n",
    "print(np.random.randint(0, cifar_param.shape[0]))\n",
    "print(np.random.randint(0, cifar_param.shape[1]))\n",
    "print(np.random.randint(0, cifar_param.shape[2]))\n",
    "print(np.random.randint(0, cifar_param.shape[3]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _compute_output_sum(imgs: ndarray,\n",
    "                        param: ndarray):\n",
    "    '''\n",
    "    Sum of every element of _output(imgs, param); the scalar used for the\n",
    "    finite-difference gradient checks.\n",
    "    '''\n",
    "    out = _output(imgs, param)\n",
    "    return out.sum()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Input grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "cifar_imgs_2 = cifar_imgs.copy()\n",
    "cifar_imgs_2[3][1][2][19] += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2.345298758707486"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_compute_output_sum(cifar_imgs_2, cifar_param) - _compute_output_sum(cifar_imgs, cifar_param)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2.3452987587074423"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_input_grad(cifar_imgs,\n",
    "            np.ones((10, 16, 32, 32)),\n",
    "            cifar_param)[3][1][2][19]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Param grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [],
   "source": [
    "cifar_param_2 = cifar_param.copy()\n",
    "cifar_param_2[0][8][0][2] += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-47.09123124155292"
      ]
     },
     "execution_count": 66,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_compute_output_sum(cifar_imgs, cifar_param_2) - _compute_output_sum(cifar_imgs, cifar_param)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-47.0912312415532"
      ]
     },
     "execution_count": 67,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_param_grad(cifar_imgs,\n",
    "            np.ones((10, 16, 32, 32)),\n",
    "            cifar_param)[0][8][0][2]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
